{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "(sphx_glr_topic_vta_tutorials_frontend_deploy_classification.py)=\n",
        "# 在 VTA 上部署预训练的视觉模型\n",
        "\n",
        "**Author**: [Thierry Moreau](https://homes.cs.washington.edu/~moreau/)\n",
        "\n",
        "本教程提供了端到端的 demo，介绍了如何在 VTA 加速器设计上运行 ImageNet 分类推理来执行 ImageNet 分类任务。它将 Relay 展示为前端编译器，它可以执行量化（VTA 只支持 int8/32 推理）和 graph packing（以便在 core 中支持张量化），从而为硬件目标处理计算图。\n",
        "\n",
        "## 安装依赖\n",
        "\n",
        "现在回到 python 代码。导入包。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "tags": [
          "remove-cell"
        ]
      },
      "outputs": [],
      "source": [
        "import set_env"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os, time\n",
        "from PIL import Image\n",
        "import numpy as np\n",
        "from matplotlib import pyplot as plt\n",
        "\n",
        "import tvm\n",
        "import logging\n",
        "from tvm.ir.transform import PassContext\n",
        "from tvm import rpc, autotvm, relay\n",
        "from tvm.contrib import graph_executor, utils, download\n",
        "# from tvm.contrib.debugger import debug_executor\n",
        "# from tvm.relay import transform\n",
        "\n",
        "import vta\n",
        "from vta.testing import simulator\n",
        "from vta.top import graph_pack\n",
        "\n",
        "# Make sure that TVM was compiled with RPC=1\n",
        "assert tvm.runtime.enabled(\"rpc\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 定义 platform\n",
        "\n",
        "在 CPU 和 VTA 上执行，并定义模型。\n",
        "\n",
        "从 [tvm/vta/vta_hw/config/vta_config.json](https://github.com/xinetzone/tvm/tree/dev/vta/vta_hw/config/vta_config.json) 文件加载 VTA 参数："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "env = vta.get_env()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "设定设备：\n",
        "\n",
        "1. 在 CPU 上推理，使用 ``device=arm_cpu``\n",
        "2. 在 FPGA 上推理，使用 ``device=vta``"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "ctx = \"vta\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "用于查找何时 start/end bit packing 的字典："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "pack_dict = {\n",
        "    \"resnet18_v1\": [\"nn.max_pool2d\", \"nn.global_avg_pool2d\"],\n",
        "    \"resnet34_v1\": [\"nn.max_pool2d\", \"nn.global_avg_pool2d\"],\n",
        "    \"resnet18_v2\": [\"nn.max_pool2d\", \"nn.global_avg_pool2d\"],\n",
        "    \"resnet34_v2\": [\"nn.max_pool2d\", \"nn.global_avg_pool2d\"],\n",
        "    \"resnet50_v2\": [\"nn.max_pool2d\", \"nn.global_avg_pool2d\"],\n",
        "    \"resnet101_v2\": [\"nn.max_pool2d\", \"nn.global_avg_pool2d\"],\n",
        "}"
      ]
    },
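    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To illustrate what graph packing does to the data layout, the sketch below (plain NumPy, independent of TVM) packs an `NCHW` tensor into the blocked `NCHWnc` layout that VTA's tensor core expects, assuming illustrative packing factors of `BATCH = 1` and `BLOCK_OUT = 16`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def pack_nchw(data, bfactor, cfactor):\n",
        "    # Split N and C by the packing factors and move the\n",
        "    # sub-blocks innermost: (N, C, H, W) -> (N//n, C//c, H, W, n, c)\n",
        "    n, c, h, w = data.shape\n",
        "    return data.reshape(n // bfactor, bfactor,\n",
        "                        c // cfactor, cfactor,\n",
        "                        h, w).transpose(0, 2, 4, 5, 1, 3)\n",
        "\n",
        "x = np.arange(1 * 32 * 2 * 2, dtype=\"int8\").reshape(1, 32, 2, 2)\n",
        "pack_nchw(x, 1, 16).shape  # -> (1, 2, 2, 2, 1, 16)\n"
      ]
    },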
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "```{note}\n",
        "``start_pack`` 和 ``stop_pack`` 标签指示从哪里开始和结束 graph packing relay pass：换句话说，从哪里开始和结束 VTA 卸载。\n",
        "```\n",
        "\n",
        "设定运行目标设备："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "env.target_vta_cpu"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "target = env.target if ctx == \"vta\" else env.target_vta_cpu"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 获取远程执行\n",
        "\n",
        "当 `env.TARGET` 为 `'pynq'` 时，重新配置 FPGA 和 runtime。否则，如果 `env.TARGET` 为 `'sim'`，则在本地执行。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "if env.TARGET not in [\"sim\", \"tsim\", \"intelfocl\"]:\n",
        "    # 如果设置环境变量，从 tracker 节点获取 remote。\n",
        "    # 要设置 tracker，您需要遵循“自动调优卷积网络用于 VTA ”教程。\n",
        "    tracker_host = os.environ.get(\"TVM_TRACKER_HOST\", None)\n",
        "    tracker_port = os.environ.get(\"TVM_TRACKER_PORT\", None)\n",
        "    # 否则，如果你有设备，你想直接从 host 编程，\n",
        "    # 确保你已经设置了下面的变量为你的板的 IP。\n",
        "    device_host = os.environ.get(\"VTA_RPC_HOST\", \"192.168.2.99\")\n",
        "    device_port = os.environ.get(\"VTA_RPC_PORT\", \"9091\")\n",
        "    if not tracker_host or not tracker_port:\n",
        "        remote = rpc.connect(device_host, int(device_port))\n",
        "    else:\n",
        "        remote = autotvm.measure.request_remote(\n",
        "            env.TARGET, \n",
        "            tracker_host, \n",
        "            int(tracker_port), \n",
        "            timeout=10000\n",
        "        )\n",
        "\n",
        "    # 重新配置 JIT 运行时和 FPGA。\n",
        "    # 通过将路径传递给 bitstream 文件而不是 None，\n",
        "    # 您可以使用自己的自定义 bitstream 编程 FPGA。\n",
        "    reconfig_start = time.time()\n",
        "    vta.reconfig_runtime(remote)\n",
        "    vta.program_fpga(remote, bitstream=None)\n",
        "    reconfig_time = time.time() - reconfig_start\n",
        "    print(f\"Reconfigured FPGA and RPC runtime in {reconfig_time:.2f}s!\")\n",
        "# 在仿真模式中，在本地托管 RPC 服务器。\n",
        "else:\n",
        "    remote = rpc.LocalSession()\n",
        "    if env.TARGET in [\"intelfocl\"]:\n",
        "        # program intelfocl aocx\n",
        "        vta.program_fpga(remote, bitstream=\"vta.bitstream\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "从远程获取执行上下文："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "if env.TARGET == \"intelfocl\":\n",
        "    ctxes = [remote.ext_dev(0), remote.cpu(0)]\n",
        "else:\n",
        "    # Graph runtime\n",
        "    ctxes = remote.ext_dev(0) if ctx == \"vta\" else remote.cpu(0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 构建 graph executor 推理\n",
        "\n",
        "从 Gluon 模型动物园抓取视觉模型，用 Relay 编译。编译步骤如下：\n",
        "\n",
        "1. 将 MXNet 前端模块翻译为 Relay 模块。\n",
        "2. 应用 8-bit 量化：这里跳过了第一个 conv 层和 dense 层，这两个层都将在 CPU 上的 fp32 中执行。\n",
        "3. 执行  graph packing 来改变张量化的数据布局。\n",
        "4. 进行常数折叠以减少算子的数量（例如，消除 batch norm multiply）。\n",
        "5. 执行对 object 文件的 relay 构建。\n",
        "6. 将 object 文件加载到远程（FPGA 设备）。\n",
        "\n",
        "加载预配置的 AutoTVM 调度："
      ]
    },
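    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a toy sketch of the 8-bit quantization step (plain NumPy, not TVM's actual quantization pass): with a global scale of 8.0, fp32 values are scaled, rounded, and clipped into int8, and are dequantized by dividing the scale back out, so the round-trip error is at most `0.5 / scale` for values inside the clipping range:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "scale = 8.0  # mirrors `global_scale=8.0` in the qconfig below\n",
        "x = np.array([-1.5, -0.2, 0.0, 0.7, 3.14], dtype=\"float32\")\n",
        "q = np.clip(np.round(x * scale), -127, 127).astype(\"int8\")  # int8 codes\n",
        "x_hat = q.astype(\"float32\") / scale  # dequantized values\n",
        "np.abs(x - x_hat).max()  # round-trip error, at most 0.5 / scale here\n"
      ]
    },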
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "ename": "AssertionError",
          "evalue": "",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
            "Cell \u001b[0;32mIn[14], line 28\u001b[0m\n\u001b[1;32m     26\u001b[0m         \u001b[38;5;28;01massert\u001b[39;00m env\u001b[38;5;241m.\u001b[39mBLOCK_IN \u001b[38;5;241m==\u001b[39m env\u001b[38;5;241m.\u001b[39mBLOCK_OUT\n\u001b[1;32m     27\u001b[0m         \u001b[38;5;66;03m# 如果目标是 intelfocl 或 sim，是否有 device annotation\u001b[39;00m\n\u001b[0;32m---> 28\u001b[0m         relay_prog \u001b[38;5;241m=\u001b[39m graph_pack(\n\u001b[1;32m     29\u001b[0m             mod[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m     30\u001b[0m             env\u001b[38;5;241m.\u001b[39mBATCH,\n\u001b[1;32m     31\u001b[0m             env\u001b[38;5;241m.\u001b[39mBLOCK_OUT,\n\u001b[1;32m     32\u001b[0m             env\u001b[38;5;241m.\u001b[39mWGT_WIDTH,\n\u001b[1;32m     33\u001b[0m             start_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnn.max_pool2d\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;66;03m#pack_dict[model][0],\u001b[39;00m\n\u001b[1;32m     34\u001b[0m             stop_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnn.global_avg_pool2d\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     35\u001b[0m             device_annot\u001b[38;5;241m=\u001b[39m(env\u001b[38;5;241m.\u001b[39mTARGET \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mintelfocl\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m     36\u001b[0m         )\n\u001b[1;32m     37\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m     38\u001b[0m     relay_prog \u001b[38;5;241m=\u001b[39m mod[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
            "File \u001b[0;32m/media/pc/data/board/arria10/lxw/tasks/tvm-new/vta/python/vta/top/graphpack.py:606\u001b[0m, in \u001b[0;36mgraph_pack\u001b[0;34m(expr, bfactor, cfactor, weight_bits, start_name, stop_name, start_name_idx, stop_name_idx, count_meta, device_annot, annot_start_name, annot_end_name)\u001b[0m\n\u001b[1;32m    599\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(expr, relay\u001b[38;5;241m.\u001b[39mFunction)\n\u001b[1;32m    600\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m (\n\u001b[1;32m    601\u001b[0m     (start_name \u001b[38;5;241m!=\u001b[39m stop_name)\n\u001b[1;32m    602\u001b[0m     \u001b[38;5;129;01mor\u001b[39;00m (start_name_idx \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m!=\u001b[39m stop_name_idx \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m    603\u001b[0m     \u001b[38;5;129;01mor\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m (start_name_idx \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m stop_name_idx \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m))\n\u001b[1;32m    604\u001b[0m     \u001b[38;5;129;01mor\u001b[39;00m (start_name_idx \u001b[38;5;241m<\u001b[39m stop_name_idx)\n\u001b[1;32m    605\u001b[0m )\n\u001b[0;32m--> 606\u001b[0m expr \u001b[38;5;241m=\u001b[39m get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)\n\u001b[1;32m    607\u001b[0m expr \u001b[38;5;241m=\u001b[39m run_opt_pass(expr, transform\u001b[38;5;241m.\u001b[39mInferType())\n\u001b[1;32m    608\u001b[0m packer \u001b[38;5;241m=\u001b[39m ExprPack(bfactor, cfactor, weight_bits)\n",
            "File \u001b[0;32m/media/pc/data/board/arria10/lxw/tasks/tvm-new/vta/python/vta/top/graphpack.py:531\u001b[0m, in \u001b[0;36mget_subgraph\u001b[0;34m(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)\u001b[0m\n\u001b[1;32m    528\u001b[0m         \u001b[38;5;28;01massert\u001b[39;00m stop_found\n\u001b[1;32m    529\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m anf\n\u001b[0;32m--> 531\u001b[0m annotated \u001b[38;5;241m=\u001b[39m _recursion(anf, \u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;28;01mFalse\u001b[39;00m, operator_current_idx)\n\u001b[1;32m    532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m run_opt_pass(annotated, transform\u001b[38;5;241m.\u001b[39mToGraphNormalForm())\n",
            "File \u001b[0;32m/media/pc/data/board/arria10/lxw/tasks/tvm-new/vta/python/vta/top/graphpack.py:494\u001b[0m, in \u001b[0;36mget_subgraph.<locals>._recursion\u001b[0;34m(anf, start_found, stop_found, operator_current_idx)\u001b[0m\n\u001b[1;32m    490\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Helper to obtain the subgraph.\"\"\"\u001b[39;00m\n\u001b[1;32m    491\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(anf, relay\u001b[38;5;241m.\u001b[39mFunction):\n\u001b[1;32m    492\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m relay\u001b[38;5;241m.\u001b[39mFunction(\n\u001b[1;32m    493\u001b[0m         anf\u001b[38;5;241m.\u001b[39mparams,\n\u001b[0;32m--> 494\u001b[0m         _recursion(anf\u001b[38;5;241m.\u001b[39mbody, start_found, stop_found, operator_current_idx),\n\u001b[1;32m    495\u001b[0m         anf\u001b[38;5;241m.\u001b[39mret_type,\n\u001b[1;32m    496\u001b[0m         anf\u001b[38;5;241m.\u001b[39mtype_params,\n\u001b[1;32m    497\u001b[0m         anf\u001b[38;5;241m.\u001b[39mattrs,\n\u001b[1;32m    498\u001b[0m     )\n\u001b[1;32m    499\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(anf, relay\u001b[38;5;241m.\u001b[39mexpr\u001b[38;5;241m.\u001b[39mLet):\n\u001b[1;32m    500\u001b[0m     value \u001b[38;5;241m=\u001b[39m anf\u001b[38;5;241m.\u001b[39mvalue\n",
            "File \u001b[0;32m/media/pc/data/board/arria10/lxw/tasks/tvm-new/vta/python/vta/top/graphpack.py:517\u001b[0m, in \u001b[0;36mget_subgraph.<locals>._recursion\u001b[0;34m(anf, start_found, stop_found, operator_current_idx)\u001b[0m\n\u001b[1;32m    511\u001b[0m operator_current_idx \u001b[38;5;241m=\u001b[39m _operator_idx_inc(value, count_meta, operator_current_idx)\n\u001b[1;32m    513\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    514\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m relay\u001b[38;5;241m.\u001b[39mexpr\u001b[38;5;241m.\u001b[39mLet(\n\u001b[1;32m    515\u001b[0m         anf\u001b[38;5;241m.\u001b[39mvar,\n\u001b[1;32m    516\u001b[0m         value,\n\u001b[0;32m--> 517\u001b[0m         _recursion(anf\u001b[38;5;241m.\u001b[39mbody, start_found, stop_found, operator_current_idx),\n\u001b[1;32m    518\u001b[0m     )\n\u001b[1;32m    519\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m BT:\n\u001b[1;32m    520\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m start_found\n",
            "File \u001b[0;32m/media/pc/data/board/arria10/lxw/tasks/tvm-new/vta/python/vta/top/graphpack.py:517\u001b[0m, in \u001b[0;36mget_subgraph.<locals>._recursion\u001b[0;34m(anf, start_found, stop_found, operator_current_idx)\u001b[0m\n\u001b[1;32m    511\u001b[0m operator_current_idx \u001b[38;5;241m=\u001b[39m _operator_idx_inc(value, count_meta, operator_current_idx)\n\u001b[1;32m    513\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    514\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m relay\u001b[38;5;241m.\u001b[39mexpr\u001b[38;5;241m.\u001b[39mLet(\n\u001b[1;32m    515\u001b[0m         anf\u001b[38;5;241m.\u001b[39mvar,\n\u001b[1;32m    516\u001b[0m         value,\n\u001b[0;32m--> 517\u001b[0m         _recursion(anf\u001b[38;5;241m.\u001b[39mbody, start_found, stop_found, operator_current_idx),\n\u001b[1;32m    518\u001b[0m     )\n\u001b[1;32m    519\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m BT:\n\u001b[1;32m    520\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m start_found\n",
            "    \u001b[0;31m[... skipping similar frames: get_subgraph.<locals>._recursion at line 517 (316 times)]\u001b[0m\n",
            "File \u001b[0;32m/media/pc/data/board/arria10/lxw/tasks/tvm-new/vta/python/vta/top/graphpack.py:517\u001b[0m, in \u001b[0;36mget_subgraph.<locals>._recursion\u001b[0;34m(anf, start_found, stop_found, operator_current_idx)\u001b[0m\n\u001b[1;32m    511\u001b[0m operator_current_idx \u001b[38;5;241m=\u001b[39m _operator_idx_inc(value, count_meta, operator_current_idx)\n\u001b[1;32m    513\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    514\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m relay\u001b[38;5;241m.\u001b[39mexpr\u001b[38;5;241m.\u001b[39mLet(\n\u001b[1;32m    515\u001b[0m         anf\u001b[38;5;241m.\u001b[39mvar,\n\u001b[1;32m    516\u001b[0m         value,\n\u001b[0;32m--> 517\u001b[0m         _recursion(anf\u001b[38;5;241m.\u001b[39mbody, start_found, stop_found, operator_current_idx),\n\u001b[1;32m    518\u001b[0m     )\n\u001b[1;32m    519\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m BT:\n\u001b[1;32m    520\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m start_found\n",
            "File \u001b[0;32m/media/pc/data/board/arria10/lxw/tasks/tvm-new/vta/python/vta/top/graphpack.py:528\u001b[0m, in \u001b[0;36mget_subgraph.<locals>._recursion\u001b[0;34m(anf, start_found, stop_found, operator_current_idx)\u001b[0m\n\u001b[1;32m    526\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    527\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m start_found\n\u001b[0;32m--> 528\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m stop_found\n\u001b[1;32m    529\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m anf\n",
            "\u001b[0;31mAssertionError\u001b[0m: "
          ]
        }
      ],
      "source": [
        "import torch\n",
        "from torchvision.models import resnet18, ResNet18_Weights\n",
        "\n",
        "# 为 ImageNet 分类器输入填充 shape 和数据类型字典\n",
        "dtype_dict = {\"data\": \"float32\"}\n",
        "input_shape = [1, 3, 224, 224]\n",
        "\n",
        "# PyTorch 模型，转换成 relay\n",
        "model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()\n",
        "\n",
        "input_data = torch.randn(input_shape)\n",
        "scripted_model = torch.jit.trace(model, input_data).eval()\n",
        "with autotvm.tophub.context(target):\n",
        "    # 度量构建的开始时间\n",
        "    build_start = time.time()\n",
        "    # 开始前端编译\n",
        "    mod, params = relay.frontend.from_pytorch(scripted_model, [(\"data\", input_shape)])\n",
        "    if target.device_name == \"vta\":\n",
        "        # 在 Relay 中执行量化\n",
        "        # 注意：为了 fold batch norm，将 `opt_level` 设置为 `3`\n",
        "        with PassContext(opt_level=3):\n",
        "            with relay.quantize.qconfig(global_scale=8.0,\n",
        "                                        skip_conv_layers=[]):\n",
        "                mod = relay.quantize.quantize(mod, params=params)\n",
        "            # 对 VTA target 进行 graph packing 和 constant folding\n",
        "            assert env.BLOCK_IN == env.BLOCK_OUT\n",
        "            # 如果目标是 intelfocl 或 sim，是否有 device annotation\n",
        "            relay_prog = graph_pack(\n",
        "                mod[\"main\"],\n",
        "                env.BATCH,\n",
        "                env.BLOCK_OUT,\n",
        "                env.WGT_WIDTH,\n",
        "                start_name=\"nn.max_pool2d\", #pack_dict[model][0],\n",
        "                stop_name=\"nn.global_avg_pool2d\",\n",
        "                device_annot=(env.TARGET == \"intelfocl\"),\n",
        "            )\n",
        "    else:\n",
        "        relay_prog = mod[\"main\"]\n",
        "\n",
        "    # 禁用 AlterOpLayout，编译 Relay 程序\n",
        "    if target.device_name != \"vta\":\n",
        "        with PassContext(opt_level=3,\n",
        "                         disabled_pass={\"AlterOpLayout\"}):\n",
        "            lib = relay.build(\n",
        "                relay_prog,\n",
        "                target=target,\n",
        "                params=params\n",
        "            )\n",
        "    else:\n",
        "        if env.TARGET == \"intelfocl\":\n",
        "            # 在 CPU 和 VTA 上运行多个目标\n",
        "            target = {\"cpu\": env.target_vta_cpu,\n",
        "                      \"ext_dev\": target}\n",
        "        with vta.build_config(\n",
        "            opt_level=3,\n",
        "            disabled_pass={\"AlterOpLayout\",\n",
        "                           \"tir.CommonSubexprElimTIR\"}\n",
        "        ):\n",
        "            lib = relay.build(relay_prog,\n",
        "                              target=target,\n",
        "                              params=params)\n",
        "    # 度量 Relay 构建时间\n",
        "    build_time = time.time() - build_start\n",
        "    logging.info(f\"{model} inference graph built in {build_time:.2f}s!\")\n",
        "\n",
        "    # 将 inference library 发送到远程 RPC 服务器\n",
        "    temp = utils.tempdir()\n",
        "    lib.export_library(temp.relpath(\"graphlib.tar\"))\n",
        "    remote.upload(temp.relpath(\"graphlib.tar\"))\n",
        "    loaded_lib = remote.load_module(\"graphlib.tar\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 进行图像分类推理\n",
        "\n",
        "只需要下载 category 文件，`synset.txt` 和输入测试图像。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# 下载 ImageNet categories\n",
        "categ_url = \"https://github.com/uwsampl/web-data/raw/main/vta/models\"\n",
        "categ_fn = \"synset.txt\"\n",
        "download.download(f\"{categ_url}/{categ_fn}\", categ_fn)\n",
        "synset = eval(open(categ_fn).read())\n",
        "# 下载测试图片\n",
        "image_url = \"https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg\"\n",
        "image_fn = \"cat.png\"\n",
        "download.download(image_url, image_fn)\n",
        "# 准备用于推理的测试图像\n",
        "image = Image.open(image_fn).resize((224, 224))\n",
        "plt.imshow(image)\n",
        "plt.show()\n",
        "image = np.array(image) - np.array([123.0, 117.0, 104.0])\n",
        "image /= np.array([58.395, 57.12, 57.375])\n",
        "image = image.transpose((2, 0, 1))\n",
        "image = image[np.newaxis, :]\n",
        "image = np.repeat(image, env.BATCH, axis=0)\n",
        "\n",
        "# 生成图执行器（graph executor） `m`。\n",
        "m = graph_executor.GraphModule(loaded_lib[\"default\"](ctxes))\n",
        "# 设置网络参数和输入\n",
        "m.set_input(**params)\n",
        "m.set_input(\"data\", image)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 执行推理并收集执行统计信息\n",
        "\n",
        "```{tip}\n",
        "更多内容参考 {meth}`tvm.runtime.Module.time_evaluator`。\n",
        "```"
      ]
    },
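    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, the mean/std reporting in the non-simulator branch below boils down to the following arithmetic (the timing values here are made up; `time_evaluator` returns per-repeat times in seconds):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical per-repeat wall-clock times in seconds\n",
        "results = [0.0123, 0.0118, 0.0125]\n",
        "mean = np.mean(results) * 1000  # ms\n",
        "std = np.std(results) * 1000    # ms\n",
        "print(f\"Performed inference in {mean:.2f}ms (std = {std:.2f})\")\n"
      ]
    },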
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "num = 4  # 为单个度量运行模块的次数\n",
        "rep = 3  # 测量的数量（由此得出 std dev）\n",
        "timer = m.module.time_evaluator(\"run\",\n",
        "                                ctxes,\n",
        "                                number=num,\n",
        "                                repeat=rep)\n",
        "\n",
        "if env.TARGET in [\"sim\", \"tsim\"]:\n",
        "    simulator.clear_stats()\n",
        "    timer()\n",
        "    sim_stats = simulator.stats()\n",
        "    print(\"\\nExecution statistics:\")\n",
        "    for k, v in sim_stats.items():\n",
        "        # 由于多次执行 workload，需要 normalize 统计数据。\n",
        "        # 注意，总是有一次 warm up 运行\n",
        "        # 因此，将整体统计数据除以 (num * rep + 1)\n",
        "        print(f\"\\t{k:<16}: {v // (num * rep + 1):>16}\")\n",
        "else:\n",
        "    tcost = timer()\n",
        "    std = np.std(tcost.results) * 1000\n",
        "    mean = tcost.mean * 1000\n",
        "    print(f\"\\nPerformed inference in {mean:.2f}ms (std = {std:.2f}) for {env.BATCH} samples\")\n",
        "    print(f\"Average per sample inference time: {mean / env.BATCH:.2f}ms\")\n",
        "\n",
        "# 得到的分类结果\n",
        "tvm_output = m.get_output(0, tvm.nd.empty(\n",
        "    (env.BATCH, 1000), \"float32\", remote.cpu(0)))\n",
        "for b in range(env.BATCH):\n",
        "    top_categories = np.argsort(tvm_output.numpy()[b])\n",
        "    # 报告 top-5 分类结果\n",
        "    print(f\"\\n{model} prediction for sample {b}\")\n",
        "    print(\"\\t#1:\", synset[top_categories[-1]])\n",
        "    print(\"\\t#2:\", synset[top_categories[-2]])\n",
        "    print(\"\\t#3:\", synset[top_categories[-3]])\n",
        "    print(\"\\t#4:\", synset[top_categories[-4]])\n",
        "    print(\"\\t#5:\", synset[top_categories[-5]])\n",
        "    # 这只是检查 5 个顶级类别之一是一种猫；\n",
        "    # 这绝不是对量化如何影响分类 accuracy 的准确评估，\n",
        "    # 而是旨在捕捉在 CI 中会影响 accuracy 的量化传递的变化。\n",
        "    cat_detected = False\n",
        "    for k in top_categories[-5:]:\n",
        "        if \"cat\" in synset[k]:\n",
        "            cat_detected = True\n",
        "    assert cat_detected"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "lib.ir_mod.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
